import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import adept_envs # type: ignore
import gym
import math
import cv2
import numpy as np
from PIL import Image
import os
import torchvision.transforms as T
from vip import load_vip
import matplotlib.pyplot as plt
import pickle
import time
# Widen tensor printing so long rows stay on one line in the logs.
torch.set_printoptions(edgeitems=10, linewidth=500)

# Module-level globals shared by the classes and the script body below.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device is", device, flush=True)

# VIP model returned by load_vip(), switched to eval mode and moved to `device`.
vip = load_vip()
vip.eval()
vip = vip.to(device)

# Image preprocessing applied before every call to `vip`.
# ToTensor() divides by 255; call sites multiply by 255.0 again before calling vip.
transforms = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
])


class Robotic_Environment: # This will only be used for evaluating a trained model
    """Evaluation-time wrapper around the ``kitchen_relax-v1`` gym environment.

    The wrapper forces the simulator into a recorded start state, steps it
    with policy actions while collecting rendered frames, exposes the current
    observation in the format the policy consumes, and writes the collected
    frames to a video file.

    NOTE(review): an earlier comment said this runs on pybullet; the code only
    touches ``env.sim.data.qpos/qvel`` and ``env.sim.forward()``, which looks
    like a MuJoCo-style simulator -- confirm.
    """

    def __init__(self, video_resolution, gaussian_noise, camera_number, reset_information, in_hand_eval):
        """Create the environment and overwrite its state with ``reset_information``.

        Args:
            video_resolution: size tuple handed to ``cv2.resize`` and
                ``cv2.VideoWriter``.
            gaussian_noise: when truthy, every qpos/qvel entry is multiplied by
                ``1 + N(0, 0.03)`` before being written to the simulator.
            camera_number: accepted for interface compatibility; not used here.
            reset_information: indexable pair; ``[0]`` is written to
                ``sim.data.qpos`` and ``[1]`` to ``sim.data.qvel``. Both need a
                ``.shape`` when ``gaussian_noise`` is truthy.
            in_hand_eval: accepted for interface compatibility; not used here.
        """
        env = gym.make('kitchen_relax-v1')
        self.env = env.env
        self.video_frames = [] # These are the frames of video saved for evaluation
        self.video_resolution = video_resolution
        self.env.reset()
        if gaussian_noise:
            mean = 0  # Mean of the Gaussian noise
            std_dev = 0.03  # Standard deviation of the Gaussian noise
            # Multiplicative noise: entries that are exactly zero stay zero.
            self.env.sim.data.qpos[:] = reset_information[0] * (1 + np.random.normal(mean, std_dev, reset_information[0].shape))
            self.env.sim.data.qvel[:] = reset_information[1] * (1 + np.random.normal(mean, std_dev, reset_information[1].shape))
        else:
            self.env.sim.data.qpos[:] = reset_information[0]
            self.env.sim.data.qvel[:] = reset_information[1]
        self.env.sim.forward() # The environment is setup

    def _render_rgb(self):
        """Render the current scene and return it as a fresh numpy array."""
        curr_frame = self.env.render(mode='rgb_array') # Capture image
        # The PIL round-trip is kept from the original implementation.
        # NOTE(review): presumably redundant for uint8 frames -- confirm before removing.
        rgb_image = Image.fromarray(np.array(curr_frame))
        return np.array(rgb_image)

    def _embed_current_frame(self):
        """Render the scene and return its VIP embedding as a flat Python list."""
        rgb_array = self._render_rgb()
        preprocessed_image = transforms(Image.fromarray(rgb_array.astype(np.uint8))).reshape(-1, 3, 224, 224)
        preprocessed_image = preprocessed_image.to(device)
        with torch.no_grad():
            # transforms' ToTensor() divided by 255, so scale back before calling vip.
            embedding = vip(preprocessed_image * 255.0)
        return embedding.cpu().tolist()[0]

    def step(self, action):
        """Apply ``action`` to the environment and record the resulting frame."""
        self.env.step(np.array(action)) # Execute some action
        rgb_array = self._render_rgb()
        # Frames are stored in BGR channel order for cv2.VideoWriter.
        bgr_array = cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR)
        bgr_array = cv2.resize(bgr_array, self.video_resolution)
        self.video_frames.append(bgr_array)

    def get_current_state(self, space): # This is the state in the format specified as input
        """Return the current observation as a flat list.

        Args:
            space: ``"joint_space"`` (``env._get_obs()`` only),
                ``"image_embedding"`` (VIP embedding of the rendered frame) or
                ``"both"`` (embedding followed by ``env._get_obs()``).

        Raises:
            ValueError: if ``space`` is none of the three supported values
                (previously this silently returned ``None``).
        """
        if space == "joint_space":
            return (self.env._get_obs()).tolist()
        if space == "image_embedding":
            return self._embed_current_frame()
        if space == "both":
            image_embedding = self._embed_current_frame()
            joint_state = (self.env._get_obs()).tolist()
            # Concatenate image + joint
            return image_embedding + joint_state
        raise ValueError(
            f"Unknown state space {space!r}; expected 'joint_space', 'both' or 'image_embedding'"
        )

    def save_video(self, video_filename, video_filename_in_hand):
        """Write every 4th recorded frame to ``video_filename`` at 30 fps (mp4v).

        ``video_filename_in_hand`` is accepted for interface compatibility and
        is not used.
        """
        video_fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        video_out = cv2.VideoWriter(video_filename, video_fourcc, 30.0, self.video_resolution)
        try:
            for frame in self.video_frames[::4]: # fast forward 4X
                video_out.write(frame)
        finally:
            # Release the writer even if a write fails.
            video_out.release()

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding followed by dropout.

    The table ``pe`` has shape ``(max_len, 1, d_model)``: even feature indices
    hold ``sin(position * div_term)`` and odd indices hold
    ``cos(position * div_term)``. It is registered as a buffer, so it moves
    with the module but is not a trainable parameter.
    """

    def __init__(self, d_model, dropout, max_len=5000):  # Assuming 5000 is the maximum length of any trajectory
        """
        Args:
            d_model: feature dimension of the inputs (odd values are supported).
            dropout: dropout probability applied after the encoding is added.
            max_len: number of positions in the precomputed table.
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # Shape: (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))  # Shape: (ceil(d_model/2),)
        angles = position * div_term  # Shape: (max_len, ceil(d_model/2))
        pe = torch.zeros(max_len, 1, d_model)  # Shape: (max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)  # Apply sine to even indices
        # For odd d_model there is one fewer odd index than even index, so the
        # last frequency is dropped for the cosine half. For even d_model the
        # slice is a no-op and the table is unchanged.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])  # Apply cosine to odd indices
        self.register_buffer('pe', pe)  # Register as buffer to avoid updating during training

    def forward(self, x, timestamp=-1):
        """Add positional encodings to ``x`` and apply dropout.

        Args:
            x: tensor of shape ``(seq_length, batch_size, d_model)``.
            timestamp: the int ``-1`` (default) encodes each row of ``x`` by
                its index along dimension 0. Otherwise a tensor of positions,
                one per batch element, used to look up the encoding that is
                added to a length-1 sequence.

        Returns:
            Tensor with the same shape as ``x``.
        """
        if isinstance(timestamp, int) and timestamp == -1:
            # pe slice is (seq_length, 1, d_model) and broadcasts over the batch.
            return self.dropout(x + self.pe[:x.size(0), :])

        timestamp = timestamp.long().to(self.pe.device)  # Shape: (batch_size,)
        positions = self.pe[timestamp]  # Shape: (batch_size, 1, d_model)
        positions = positions.squeeze(1).unsqueeze(0)  # Shape: (1, batch_size, d_model)
        return self.dropout(x + positions)

class CustomTransformerEncoderLayer(nn.Module):
    """Post-norm encoder layer: self-attention, then a two-layer feed-forward net.

    Each sub-block is wrapped in a residual connection followed by LayerNorm.
    ``self.dropout`` is created but not applied in ``forward``; the
    ``dropout`` argument only reaches the attention module.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, activation="relu",dropout=0.0):
        """
        Args:
            d_model: feature dimension of the sequence.
            nhead: number of attention heads.
            dim_feedforward: hidden width of the feed-forward net.
            activation: ``"relu"`` selects ReLU; any other value selects GELU.
            dropout: dropout probability passed to the attention module.
        """
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        if activation == "relu":
            self.activation = nn.ReLU()
        else:
            self.activation = nn.GELU()
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src , attention_mask):
        """Encode ``src``; ``attention_mask`` is used as the key padding mask."""
        # Self-attention with residual connection
        attended = self.self_attn(src, src, src, key_padding_mask=attention_mask)[0]
        hidden = self.norm1(src + attended)
        # Feedforward network with residual connection
        expanded = self.activation(self.linear1(hidden))
        projected = self.linear2(expanded)
        return self.norm2(hidden + projected)

class CustomTransformerDecoderLayer(nn.Module):
    """Decoder layer: cross-attention over ``memory`` plus a feed-forward net.

    The per-head attention weights of the latest call are kept in
    ``self.attn_weights``. ``norm2`` and ``dropout`` are created but not used
    in ``forward``; ``norm2`` still contributes parameters to the module.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, activation="relu",dropout=0.0):
        """
        Args:
            d_model: feature dimension of target and memory.
            nhead: number of attention heads.
            dim_feedforward: hidden width of the feed-forward net.
            activation: ``"relu"`` selects ReLU; any other value selects GELU.
            dropout: dropout probability passed to the attention module.
        """
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        if activation == "relu":
            self.activation = nn.ReLU()
        else:
            self.activation = nn.GELU()
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout = nn.Dropout(dropout)
        self.attn_weights = None  # To store attention weights

    def forward(self, tgt, memory, attention_mask): # memory here is the encoded subgoals from the encoder
        """Attend from ``tgt`` to ``memory``; ``attention_mask`` masks padded memory keys."""
        # Cross attention; average_attn_weights=False keeps one weight map per head.
        attended, per_head_weights = self.cross_attn(
            tgt,
            memory,
            memory,
            key_padding_mask=attention_mask,
            average_attn_weights=False,
        )
        self.attn_weights = per_head_weights # These are the attention weights
        normed = self.norm1(tgt + attended)  # Residual connection
        # No second layer norm here; an activation is applied to the output instead
        # (the original author reports this performs better).
        hidden = self.activation(self.linear1(normed))
        return self.activation(self.linear2(hidden))

class TransformerPolicy_Custom(nn.Module): # This is transformer policy with customised encoder decoder
    """Policy that maps (subgoal sequence, current state, timestamp) to actions.

    Subgoals get index-based positional encodings and pass through the encoder
    layers. The current state gets the positional encoding of its timestamp
    and cross-attends to the encoded subgoals in every decoder layer. A final
    linear layer produces the action vector.
    """

    def __init__(self,subgoal_dimension,state_dimension,output_dimension,subgoal_seq_length,nhead,num_encoder_layers,num_decoder_layers, dim_feedforward,dropout, activation="relu"):
        """
        Args:
            subgoal_dimension: feature size of each subgoal (encoder width).
            state_dimension: feature size of the state (decoder width);
                subgoal_dim and state_dim must ideally be equal for cross attention.
            output_dimension: size of the predicted action vector.
            subgoal_seq_length: stored on the module; not otherwise used here.
            nhead: attention heads per layer.
            num_encoder_layers: number of encoder layers.
            num_decoder_layers: number of decoder layers.
            dim_feedforward: hidden width of each feed-forward net.
            dropout: dropout probability for positional encoders and attention.
            activation: forwarded to every encoder/decoder layer.
        """
        super().__init__()
        self.subgoal_seq_length = subgoal_seq_length
        self.subgoal_dimension = subgoal_dimension
        self.state_dimension = state_dimension
        self.output_dimension = output_dimension # This is action

        self.subgoal_pos_encoder = PositionalEncoding(subgoal_dimension, dropout)
        encoder_layers = [
            CustomTransformerEncoderLayer(
                d_model=subgoal_dimension,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                activation=activation,
                dropout=dropout,
            )
            for _ in range(num_encoder_layers)
        ]
        self.transformer_encoder_layers = nn.ModuleList(encoder_layers)

        self.state_pos_encoder = PositionalEncoding(state_dimension, dropout)
        decoder_layers = [
            CustomTransformerDecoderLayer(
                d_model=state_dimension,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                activation=activation,
                dropout=dropout,
            )
            for _ in range(num_decoder_layers)
        ]
        self.transformer_decoder_layers = nn.ModuleList(decoder_layers)

        self.output_layer = nn.Linear(state_dimension, output_dimension)
        # One list per decoder layer; filled only when forward(..., inference=True).
        self.attention_weights = [[] for _ in range(num_decoder_layers)]

    def forward(self, subgoals, current_state, attention_mask, timestamp, inference=False): # attention mask given here is for padding
        """Predict actions.

        Args:
            subgoals: ``(batch_size, subgoal_seq_length, subgoal_dimension)``.
            current_state: ``(batch_size, state_dimension)``.
            attention_mask: key padding mask passed to every attention layer.
            timestamp: per-sample positions used for the state's positional encoding.
            inference: when truthy, each decoder layer's attention weights are
                appended to ``self.attention_weights``.

        Returns:
            Tensor of shape ``(batch_size, output_dimension)``.
        """
        # Encoder: (subgoal_seq_length, batch_size, subgoal_dimension)
        encoded = self.subgoal_pos_encoder(subgoals.permute(1, 0, 2))
        for encoder_layer in self.transformer_encoder_layers:
            encoded = encoder_layer(encoded, attention_mask) # Self attention of subgoals

        # Decoder input: (1, batch_size, state_dimension), encoded by timestamp
        decoded = self.state_pos_encoder(current_state.unsqueeze(0), timestamp)
        for layer_index, decoder_layer in enumerate(self.transformer_decoder_layers):
            decoded = decoder_layer(decoded, encoded, attention_mask)
            if inference:
                weights = decoder_layer.attn_weights.squeeze(1).squeeze(0)
                self.attention_weights[layer_index].append(weights.cpu().tolist())

        # (batch_size, state_dimension) -> (batch_size, output_dimension)
        return self.output_layer(decoded.squeeze(0))

    def reset_attention_weights(self, num_decoder_layers):
        """Discard recorded attention weights and start one empty list per decoder layer."""
        self.attention_weights = [[] for _ in range(num_decoder_layers)]

class TrajectoryDataset(Dataset): # Dataset for Behavioural cloning
    """Flattens recorded trajectories into per-timestep behavioural-cloning samples.

    Every sample is ``[subgoals, state, action, attention_mask, timestamp]``.

    Per-timestep rows built by ``_read_csv`` are indexed with fixed slices:
    ``[0:60]`` observation, ``[60:63]`` zero buffer, ``[63]`` timestep,
    ``[64:73]`` action, ``[73:1097]`` image embedding. These slices assume
    60-dim observations, 9-dim actions and a 1024-dim ``vip`` output -- TODO
    confirm against the recorded data.
    """

    def __init__(self, Trajectory_directories, base_directory , state_space, subgoal_format, subgoal_directory_path, camera_number,subgoal_seq_length, action_chunking):
        """Store the configuration and eagerly load every trajectory.

        Args:
            Trajectory_directories: directory names under ``base_directory``.
            base_directory: root folder of the recordings.
            state_space: ``"joint_space"``, ``"both"`` or ``"image_embedding"``.
            subgoal_format: same three options, applied to subgoals.
            subgoal_directory_path: sub-path (inside each trajectory directory)
                holding the ``<frame index>.png`` subgoal files.
            camera_number: selects ``camera_<n>.avi`` in each trajectory directory.
            subgoal_seq_length: subgoal lists shorter than this are zero padded.
            action_chunking: number of consecutive actions per sample.
        """
        self.Trajectory_directories = Trajectory_directories # List of all the directories
        self.base_directory = base_directory
        self.state_space = state_space
        self.subgoal_format = subgoal_format
        self.subgoal_directory_path = subgoal_directory_path
        self.camera_number = camera_number
        self.subgoal_seq_length = subgoal_seq_length # Length of subgoals, pad with 0 if not exactly equal
        self.action_chunking = action_chunking
        self.trajectories = self._load_trajectories()

    @staticmethod
    def _select_features(row, feature_format):
        """Return the part of ``row`` selected by ``feature_format``."""
        if feature_format == "joint_space":
            return row[0:60]
        if feature_format == "both":
            return row[73:1097] + row[0:60] # image embedding followed by joints
        if feature_format == "image_embedding":
            return row[73:1097]
        raise ValueError(
            f"Unknown format {feature_format!r}; expected 'joint_space', 'both' or 'image_embedding'"
        )

    @staticmethod
    def _chunk_actions(data, start, chunk_length):
        """Concatenate the actions of rows ``start .. start+chunk_length-1``.

        Steps past the end of the trajectory contribute nine zeros each. The
        action of row ``start`` is always included.
        """
        action = data[start][64:73] # slicing copies, so `data` is never mutated
        for j in range(start + 1, start + chunk_length):
            if j >= len(data):
                action += [0., 0., 0., 0., 0., 0., 0., 0., 0.]
            else:
                action += data[j][64:73]
        return action

    def _append_image_embeddings(self, data, directory):
        """Extend every row of ``data`` with the VIP embedding of its video frame."""
        # Use the instance's camera number rather than a module-level global,
        # so the class also works when it is imported from another module.
        video_path = f"{self.base_directory}/{directory}/camera_{self.camera_number}.avi"
        cap = cv2.VideoCapture(video_path)
        try:
            for i, row in enumerate(data):
                ret, frame = cap.read()  # ret is a boolean indicating success, frame is the image
                if not ret:
                    raise RuntimeError(
                        f"Could not read frame {i} of {len(data)} from {video_path}"
                    )
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                preprocessed_image = transforms(Image.fromarray(frame.astype(np.uint8))).reshape(-1, 3, 224, 224)
                preprocessed_image = preprocessed_image.to(device)
                with torch.no_grad():
                    # transforms' ToTensor() divided by 255, so scale back before calling vip.
                    subgoal_embedding = vip(preprocessed_image * 255.0)
                row.extend(subgoal_embedding.cpu().tolist()[0])
        finally:
            cap.release()

    def _load_subgoals(self, data, directory):
        """Return ``(subgoals, attention_mask)`` for one trajectory.

        Subgoal frame indices are the numeric stems of the ``.png`` files in
        the subgoal directory. The smallest index is dropped (presumably the
        initial frame). Lists shorter than ``subgoal_seq_length`` are padded
        with zero vectors whose mask entry is ``True``; longer lists are left
        untouched.
        """
        subgoals_directory = f"{self.base_directory}/{directory}/{self.subgoal_directory_path}"
        frame_indices = sorted(
            int(name.replace('.png', ''))
            for name in os.listdir(subgoals_directory)
            if name.endswith('.png')
        )
        subgoal_indices = frame_indices[1:]
        if not subgoal_indices:
            raise ValueError(
                f"No subgoal frames left after dropping the first one in {subgoals_directory}"
            )
        subgoals = [self._select_features(data[index], self.subgoal_format) for index in subgoal_indices]
        attention_mask = [False] * len(subgoals) # False means this is not padding
        missing = self.subgoal_seq_length - len(subgoals)
        if missing > 0: # Zero padding
            padding = [0] * len(subgoals[0])
            subgoals.extend([padding] * missing)
            attention_mask.extend([True] * missing) # True means padding, excluded from attention
        return subgoals, attention_mask

    def _read_csv(self, file_path, directory): # file here is a pkl file
        """Load one trajectory and return its list of samples.

        Despite the name, ``file_path`` is a pickle file with at least the keys
        ``'observations'`` and ``'actions'``.

        Returns:
            One ``[subgoals, state, action, attention_mask, timestep]`` list
            per timestep. ``subgoals`` and ``attention_mask`` are the same list
            objects in every sample of the trajectory.

        Raises:
            RuntimeError: if the camera video has fewer readable frames than
                the trajectory has observations.
            ValueError: if no subgoal frames exist or a format string is unknown.
        """
        # NOTE(security): pickle.load can execute arbitrary code; only load trusted recordings.
        with open(file_path, 'rb') as f: # Read the pickle file
            data_dict = pickle.load(f)

        observations = data_dict['observations']
        actions = data_dict['actions']
        # Row: observation columns + 3 buffer columns + 1 timestamp + action columns
        data = [
            list(observations[i]) + [0., 0., 0.] + [i] + list(actions[i])
            for i in range(observations.shape[0])
        ]
        self._append_image_embeddings(data, directory)
        subgoals, attention_mask = self._load_subgoals(data, directory)

        output = [] # list of subgoals, state, action, attention mask, timestamp
        for i, row in enumerate(data):
            output.append([
                subgoals,
                self._select_features(row, self.state_space),
                self._chunk_actions(data, i, self.action_chunking),
                attention_mask,
                i,
            ])
        return output

    def _load_trajectories(self):
        """Read ``data.pkl`` of every trajectory directory and concatenate the samples."""
        trajectories = []
        for directory in self.Trajectory_directories:
            file_path = f"{self.base_directory}/{directory}/data.pkl"
            trajectories.extend(self._read_csv(file_path, directory))
        return trajectories

    def __len__(self):
        """Return the total number of samples across all trajectories."""
        return len(self.trajectories)

    def __getitem__(self, idx): # This gives the exact subgoal, state, action , mask , timestamp tuple
        """Return sample ``idx`` as tensors ``(subgoals, state, action, attention_mask, timestamp)``."""
        trajectory_data = self.trajectories[idx]
        subgoals = torch.tensor(trajectory_data[0], dtype=torch.float32)
        state = torch.tensor(trajectory_data[1], dtype=torch.float32)
        action = torch.tensor(trajectory_data[2], dtype=torch.float32)
        attention_mask = torch.tensor(trajectory_data[3], dtype=torch.bool)
        timestamp = torch.tensor(trajectory_data[4], dtype=torch.float32)
        return (subgoals, state, action, attention_mask, timestamp)

def find_largest_number(file_path):
    """Return the integer that starts the last non-blank line of a text file.

    The file is read line by line; the first whitespace-separated token of the
    last line that contains any non-whitespace text is converted with ``int``.
    The value is the last entry's number, not a maximum over all lines.

    Args:
        file_path: path of the text file to read.

    Returns:
        The leading integer of the last non-blank line.

    Raises:
        ValueError: if the file has no non-blank line, or the leading token is
            not an integer.
    """
    last_entry = None
    with open(file_path, 'r') as file:
        for line in file:
            # Skip blank lines so a trailing empty line does not hide the last entry.
            if line.strip():
                last_entry = line
    if last_entry is None:
        raise ValueError(f"{file_path} contains no entries")
    return int(last_entry.split()[0])

if __name__ == '__main__':
    # Parameters
    train = True
    eval = True
    output_dimension = 9 # Action will always be 9 dimensional 7 dimension joint angles + 2 dimension gripper

    state_space = "both" # "joint_space", "both", "image_embedding"
    if(state_space == "joint_space"):
        state_dimension =60
    elif(state_space == "both"):
        state_dimension =1024+60 # image , joint
    elif(state_space == "image_embedding"):
        state_dimension = 1024 # See how many dimension is the image embedding????

    subgoal_format = "both" # "joint_space", "both", "image_embedding"
    if(subgoal_format == "joint_space"):
        subgoal_dimension= 60
    elif(subgoal_format == "both"):
        subgoal_dimension = 1024+60
    elif(subgoal_format == "image_embedding"):
        subgoal_dimension = 1024 # See how many dimension is the image embedding????

    Trajectory_directories = ['1.1', '1.2', '1.3', '1.4', '1.5', '2.1', '2.2', '2.3', '2.4', '2.5', '3.1', '3.2', '3.3', '3.4', '3.5', '4.1', '4.2', '4.3', '4.4', '4.5', '5.1', '5.2', '5.3', '5.4', '5.5', '6.1', '6.2', '6.3', '6.4', '6.5', '7.1', '7.2', '7.3', '7.4', '7.5', '8.1', '8.2', '8.3', '8.4', '8.5', '9.1', '9.2', '9.3', '9.4', '9.5', '10.1', '10.2', '10.3', '10.4', '10.5', '11.1', '11.2', '11.3', '11.4', '11.5', '12.1', '12.2', '12.3', '12.4', '12.5', '13.1', '13.2', '13.3', '13.4', '13.5', '14.1', '14.2', '14.3', '14.4', '14.5', '15.1', '15.2', '15.3', '15.4', '15.5', '16.1', '16.2', '16.3', '16.4', '16.5', '17.1', '17.2', '17.3', '17.4', '17.5', '18.1', '18.2', '18.3', '18.4', '18.5', '19.1', '19.2', '19.3', '19.4', '19.5', '20.1', '20.2', '20.3', '20.4', '20.5', '21.1', '21.2', '21.3', '21.4', '21.5', '22.1', '22.2', '22.3', '22.4', '22.5', '23.1', '23.2', '23.3', '23.4', '23.5', '24.1', '24.2', '24.3', '24.4', '24.5', '25.1', '25.2', '25.3', '25.4', '25.5']
    list_of_subgoals_directory_to_eval = ['1.1', '1.2', '1.3', '1.4', '1.5', '1.6', '1.7', '1.8', '1.9', '1.10', '2.1', '2.2', '2.3', '2.4', '2.5', '2.6', '2.7', '2.8', '2.9', '2.10', '3.1', '3.2', '3.3', '3.4', '3.5', '3.6', '3.7', '3.8', '3.9', '3.10', '4.1', '4.2', '4.3', '4.4', '4.5', '4.6', '4.7', '4.8', '4.9', '4.10', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7', '5.8', '5.9', '5.10', '6.1', '6.2', '6.3', '6.4', '6.5', '6.6', '6.7', '6.8', '6.9', '6.10', '7.1', '7.2', '7.3', '7.4', '7.5', '7.6', '7.7', '7.8', '7.9', '7.10', '8.1', '8.2', '8.3', '8.4', '8.5', '8.7', '8.8', '8.9', '8.10', '9.1', '9.2', '9.3', '9.4', '9.5', '9.6', '9.7', '9.8', '9.9', '9.10', '10.1', '10.2', '10.3', '10.4', '10.5', '10.6', '10.7', '10.8', '10.9', '10.10', '11.1', '11.2', '11.3', '11.4', '11.5', '11.6', '11.7', '11.8', '11.9', '11.10', '12.1', '12.2', '12.3', '12.4', '12.5', '12.6', '12.7', '12.8', '12.9', '12.10', '13.1', '13.2', '13.3', '13.4', '13.5', '13.6', '13.7', '13.8', '13.9', '13.10', '14.1', '14.2', '14.3', '14.4', '14.5', '14.6', '14.7', '14.8', '14.9', '14.10', '15.1', '15.2', '15.3', '15.4', '15.5', '15.6', '15.7', '15.8', '15.9', '15.10', '16.1', '16.2', '16.3', '16.4', '16.5', '16.6', '16.7', '16.8', '16.9', '16.10', '17.1', '17.2', '17.3', '17.4', '17.5', '17.6', '17.7', '17.8', '17.9', '17.10', '18.1', '18.2', '18.3', '18.4', '18.5', '18.6', '18.7', '18.8', '18.9', '18.10', '19.1', '19.2', '19.3', '19.4', '19.5', '19.6', '19.7', '19.8', '19.9', '19.10', '20.1', '20.2', '20.3', '20.4', '20.5', '20.6', '20.7', '20.8', '20.9', '20.10', '21.1', '21.2', '21.3', '21.4', '21.5', '21.6', '21.7', '21.8', '21.9', '21.10', '22.1', '22.2', '22.3', '22.4', '22.5', '22.6', '22.7', '22.8', '22.9', '22.10', '23.1', '23.2', '23.3', '23.4', '23.5', '23.6', '23.7', '23.8', '23.9', '23.10', '24.1', '24.2', '24.3', '24.4', '24.5', '24.6', '24.7', '24.8', '24.9', '24.10', '25.1', '25.2', '25.3', '25.4', '25.5', '25.6', '25.7', '25.8', '25.9', '25.10']
    total_number_of_iterations=1 # number of iterations per task (helpful with gaussian noise)
    num_epochs = 1000 # number of epochs on the training dataset
    
    lr =  0.0003 # learning rate
    dropout=0.1
    subgoal_seq_length = 20 # This is the maximum number of subgoals going inside the decoder, rest all will be zero padded. This is done to ensure lots of data in the same batch
    nhead=4 # number of attention heads. state dimension must be divisible by attention heads
    num_encoder_layers=1 # number of encoder layers
    num_decoder_layers=3 # number of decoder layers
    dim_feedforward=1024 # dimension of feedforward network
    gaussian_noise = False # gaussian noise on the start state
    action_chunking = 10 # Action chunking = 1 means only 1 step prediction
    temporal_ensemble = 0 # Weight given for combining actions
    in_hand_eval = False # Get in hand camera video or not
    camera_number = 2 # Camera for subgoals

    output_dimension*=action_chunking
    subgoal_directory_path = f"decomposed_frames/mininterval_18/divisions_1/gamma_0.08/camera_{camera_number}" # Can change 0.08 to something else if required
    saving_formatter = str(find_largest_number("./Parameter_mappings.txt")+1)

    with open('./Parameter_mappings.txt', 'a') as file:
        file.write(f'{saving_formatter}        : state_space_{state_space}_num_epochs_{num_epochs}_lr_{lr}_dropout_{dropout}_subgoal_seq_length_{subgoal_seq_length}_nhead_{nhead}_num_encoder_layers_{num_encoder_layers}_num_decoder_layers_{num_decoder_layers}_dim_feedforward_{dim_feedforward}_subgoal_{subgoal_format}_camera_{camera_number}_gaussian_noise_{gaussian_noise}_action_chunking_{action_chunking}_temporal_ensemble_{temporal_ensemble}_Training_directory_{Trajectory_directories}\n')  # Add a newline character to separate lines
    model_dump_file_path = f"./Trained_Models/{saving_formatter}.pth"
    base_directory = f"./../../Data_Franka_Kitchen"
    model = TransformerPolicy_Custom(subgoal_dimension,state_dimension,output_dimension,subgoal_seq_length, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout).to(device)
    print(model, flush=True)
    total_params = sum(p.numel() for p in model.parameters())
    print("Total number of parameters in the neural network is: ", total_params, flush=True)

    if(train):
        train_start_time = time.time()
        trajectory_dataset = TrajectoryDataset(Trajectory_directories, base_directory , state_space, subgoal_format, subgoal_directory_path, camera_number, subgoal_seq_length, action_chunking)
        data_loader = DataLoader(trajectory_dataset, batch_size=512, shuffle=True )
        optimizer = optim.Adam(model.parameters(), lr=lr)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs) # cosine decay of learning rate
        loss_function = nn.MSELoss()
        # Supervised Learning Loop
        for epoch in range(num_epochs):
            model.train()  # Set the model to training mode
            running_loss = 0.0  # Initialize running loss for the epoch
            num_batches = 0     # Initialize batch counter
            for batch_idx, (subgoals, states, actions, attention_mask, timestamp) in enumerate(data_loader):
                subgoals = subgoals.to(device)        # Shape: (batch_size, subgoal_seq_length, subgoal_dim)
                states = states.to(device)            # Shape: (batch_size, state_dim)
                actions = actions.to(device)          # Shape: (batch_size, action_dim)
                attention_mask = attention_mask.to(device)      # Shape: (batch_size, subgoal_seq_length)
                timestamp = timestamp.to(device)           # Shape: (batch_size,)

                optimizer.zero_grad()
                predicted_actions = model(subgoals,states,attention_mask, timestamp)
                loss = loss_function(predicted_actions, actions)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()  # Accumulate loss
                num_batches += 1             # Increment batch counter

            scheduler.step()

            if(epoch%50 == 0): # Print loss every 50 epochs
                current_lr = optimizer.param_groups[0]['lr'] # Current learning rate
                print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/num_batches}, Learning Rate: {current_lr}", flush=True)

        torch.save(model.state_dict(), model_dump_file_path)
        print(f"Model saved to {model_dump_file_path}", flush=True)
        
        print("Time taken to train the model is ", (time.time() - train_start_time)/3600 , " hrs")

    if(eval):
        model.load_state_dict(torch.load(model_dump_file_path))
        model.eval()  # Set the model to evaluation mode
        length_of_trajectories_during_inference = {}
        reset_info_of_trajectories_during_inference = {} # information about the trajectory to infer
        for directory in list_of_subgoals_directory_to_eval:
            file_path = f"{base_directory}/{directory}/data.pkl" # pkl file path
            with open(file_path, 'rb') as f: # Read the pickel file
                data_dict = pickle.load(f)
                length_of_trajectories_during_inference[directory] = data_dict['observations'].shape[0]
                reset_info_of_trajectories_during_inference[directory] = (data_dict['init_qpos'] , data_dict['init_qvel'])

        for iteration_number in range(1,total_number_of_iterations+1,1): # Number of times to evaluate a single trajectory, to get the evaluation metrics
            for directory_for_subgoals in list_of_subgoals_directory_to_eval: # These are all the trajectories to get subgoals from and evaluate
                model.reset_attention_weights(num_decoder_layers) # Reset the attention weights 
                video_resolution = (224, 224) # This is for resolution for evaluation videos
                reset_information  = reset_info_of_trajectories_during_inference[directory_for_subgoals]
                robot_env = Robotic_Environment(video_resolution, gaussian_noise, camera_number, reset_information, in_hand_eval)

                def robot_inference(directory_for_subgoals): # Function to actually evaluate the neural network
                    max_steps = length_of_trajectories_during_inference[directory_for_subgoals]
                    subgoals_directory = f"{base_directory}/{directory_for_subgoals}/{subgoal_directory_path}"
                    files = os.listdir(subgoals_directory)
                    png_files = [f for f in files if f.endswith('.png')]
                    numbers = [int(f.replace('.png', '')) for f in png_files]
                    list_of_subgoals = sorted(numbers) # This is the sorted list of all the subgoals for some trajectory
                    list_of_subgoals.pop(0) # Dont want initial state to be a subgoal
                    actual_subgoals = [] # This is either 8 or 4 or 1024 dimensional
                    if(subgoal_format == "joint_space"):
                        file_path = f"{base_directory}/{directory_for_subgoals}/data.pkl"
                        with open(file_path, 'rb') as f: # Read the pickel file
                            data_dict = pickle.load(f)
                        observations = data_dict['observations']  # Shape: (244, 60)
                        actions = data_dict['actions']  # Shape: (244, 9)
                        data = []
                        for i in range(observations.shape[0]):
                            observation = observations[i]
                            action = actions[i]
                            row = list(observation) + [0., 0., 0.] + [i] + list(action) # Create the row: 60 observation columns + 3 buffer columns + 1 timestamp + 9 action columns
                            data.append(row) # data now in csv format
                        for subgoal_index in list_of_subgoals:
                            actual_subgoals.append(data[subgoal_index][:60])
                    elif(subgoal_format == "both"):
                        video_path = f"{base_directory}/{directory_for_subgoals}/camera_{camera_number}.avi"
                        cap = cv2.VideoCapture(video_path)
                        for subgoal_index in list_of_subgoals:
                            cap.set(cv2.CAP_PROP_POS_FRAMES, subgoal_index)
                            ret, frame = cap.read()
                            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                            preprocessed_image = transforms(Image.fromarray(frame.astype(np.uint8))).reshape(-1, 3, 224, 224)
                            preprocessed_image = preprocessed_image.to(device)
                            with torch.no_grad():
                                subgoal_embedding = vip(preprocessed_image * 255.0)
                            actual_subgoals.append(subgoal_embedding.cpu().tolist()[0])
                        cap.release()
                        file_path = f"{base_directory}/{directory_for_subgoals}/data.pkl"
                        with open(file_path, 'rb') as f: # Read the pickel file
                            data_dict = pickle.load(f)
                        observations = data_dict['observations']  # Shape: (244, 60)
                        actions = data_dict['actions']  # Shape: (244, 9)
                        data = []
                        for i in range(observations.shape[0]):
                            observation = observations[i]
                            action = actions[i]
                            row = list(observation) + [0., 0., 0.] + [i] + list(action) # Create the row: 60 observation columns + 3 buffer columns + 1 timestamp + 9 action columns
                            data.append(row) # data now in csv format
                        for iterator in range(len(list_of_subgoals)):
                            subgoal_index = list_of_subgoals[iterator]
                            actual_subgoals[iterator] += data[subgoal_index][:60] # Add the 
                    elif(subgoal_format == "image_embedding"):
                        video_path = f"{base_directory}/{directory_for_subgoals}/camera_{camera_number}.avi"
                        cap = cv2.VideoCapture(video_path)
                        for subgoal_index in list_of_subgoals:
                            cap.set(cv2.CAP_PROP_POS_FRAMES, subgoal_index)
                            ret, frame = cap.read()
                            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                            preprocessed_image = transforms(Image.fromarray(frame.astype(np.uint8))).reshape(-1, 3, 224, 224)
                            preprocessed_image = preprocessed_image.to(device)
                            with torch.no_grad():
                                subgoal_embedding = vip(preprocessed_image * 255.0)
                            actual_subgoals.append(subgoal_embedding.cpu().tolist()[0])
                        cap.release()

                    attention_mask = [False for _ in range(len(actual_subgoals))] # Use actual subgoals as input to the encoder

                    if len(actual_subgoals) < subgoal_seq_length: # Zero Padding 
                        padding = []
                        for i in range(len(actual_subgoals[0])):
                            padding.append(0)
                        for i in range(subgoal_seq_length - len(actual_subgoals)):
                            actual_subgoals.append(padding)
                            attention_mask.append(True) # True in mask means it is zero padding and no attention for this

                    actual_subgoals = torch.tensor( [actual_subgoals] , dtype=torch.float32).to(device) # extra [] to keep batch size 1
                    attention_mask = torch.tensor( [attention_mask] ).to(device)

                    Buffer = [[] for _ in range(max_steps + action_chunking)]  # Initialize buffer correctly
                    for i in range(max_steps):  # i is the current timestamp
                        timestamp = torch.tensor([i]).to(device)
                        if i % 100 == 0:
                            print(f"Timestamp: {i}/{max_steps}", flush=True)
                        state = robot_env.get_current_state(state_space)
                        state_tensor = torch.tensor([state], dtype=torch.float32).to(device)
                        with torch.no_grad(): 
                            action = model(actual_subgoals, state_tensor, attention_mask, timestamp, inference=True)  # Model predicts action chunks
                        action = action.cpu().numpy().flatten()  # Convert to numpy and flatten
                        action = action.reshape(action_chunking, output_dimension // action_chunking)  # Reshape into action chunks

                        for j in range(action_chunking): # Add the action chunks to the buffer
                            Buffer[i + j].append(action[j])
                        weights = np.exp(-temporal_ensemble * np.arange(len(Buffer[i])))  # Perform temporal ensemble: weighted average of the actions
                        weights /= weights.sum()  # Normalize weights
                        current_action = np.sum([w * a for w, a in zip(weights, Buffer[i])], axis=0)
                        current_action = current_action.tolist()  # Convert to list before passing to `step`
                        robot_env.step(current_action) # It accepts both list and np array

                    return list_of_subgoals # This is the list of subgoals 

                # ---- Evaluate one rollout, then persist its videos and attention heatmaps ----
                print(f"Evaluating subgoals from {directory_for_subgoals}, iteration number {iteration_number}...", flush=True)
                # robot_inference drives the environment and returns the subgoal indices it used
                list_of_subgoals = robot_inference(directory_for_subgoals)

                graphs_path = f"./Graphs/{saving_formatter}/subgoals_{directory_for_subgoals}/itr_{iteration_number}"
                video_path = f"./Evaluation/{saving_formatter}/subgoals_{directory_for_subgoals}"
                os.makedirs(graphs_path, exist_ok=True) # Directory to save heat maps
                os.makedirs(video_path, exist_ok=True) # Directory to save Evaluation Videos

                video_filename = f"{video_path}/{iteration_number}.mp4"
                video_filename_in_hand = f"{video_path}/{iteration_number}_in_hand.mp4"
                robot_env.save_video(video_filename, video_filename_in_hand)

                # NOTE(review): presumably model.attention_weights is indexed as
                # (decoder layer, timestep, head, subgoal) once squeezed -- confirm against the model.
                heatmap_data = np.array(model.attention_weights).squeeze()
                if heatmap_data.ndim == 3:
                    # squeeze() dropped an axis; re-insert it at position 2 so the
                    # [:, head, :] indexing below still works (extra dimension for number of heads)
                    heatmap_data = np.expand_dims(heatmap_data, axis=2)
                for i in range(num_decoder_layers):  # Loop over decoder layers
                    layer_data = heatmap_data[i]
                    for head in range(layer_data.shape[1]):  # Loop over attention heads
                        head_map = layer_data[:, head, :]  # 2-D map that is both plotted and pickled
                        plt.figure()  # Create a new figure for each plot
                        plt.imshow(head_map, cmap='hot_r', aspect='auto', interpolation='nearest')
                        plt.colorbar()  # Show color scale
                        # '\n' puts the layer/head info on a second title line; the previous
                        # '\L' was an invalid escape that rendered a literal backslash.
                        plt.title(f'Heatmap of Attention Weights\nLayer {i+1}, Attention Head {head+1}')
                        plt.xlabel('Subgoal Number')
                        plt.ylabel('Timestep')
                        y_ticks = np.arange(0, layer_data.shape[0], 100)  # Every 100 indices
                        plt.yticks(ticks=y_ticks, labels=y_ticks)
                        x_ticks = np.arange(layer_data.shape[2])  # Tick positions for subgoals
                        x_labels = x_ticks + 1  # Adjust labels to start at 1
                        plt.xticks(ticks=x_ticks, labels=x_labels, rotation=90, ha='right')
                        for subgoal in list_of_subgoals: # Plot horizontal lines for the subgoals
                            plt.axhline(y=subgoal, color='blue', linestyle='--', linewidth=0.8)
                        filename_base = f"{graphs_path}/decoder_layer{i+1}_head{head+1}"
                        plt.savefig(f"{filename_base}.png", dpi=300, bbox_inches='tight')
                        plt.close()  # Close the figure to free memory and prevent overlap
                        # Store the raw data next to the image so the plot can be regenerated later
                        with open(f"{filename_base}.pkl", 'wb') as f:
                            pickle.dump({'subgoals': list_of_subgoals,'heatmap': head_map}, f)